"""Script for multi-gpu training."""
import json
import logging
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy('file_system')
import os
import sys

sys.path.append("/media/liyuke/share/AAA/part3/pose_estimation_master/")
import numpy as np
from eval_ap import validate_pose
import torch.nn as nn
import torch.utils.data
from tensorboardX import SummaryWriter
from tqdm import tqdm
from datasets.air2land import Air2land
from opt import cfg, logger, opt
from utils.logger import board_writing, debug_writing
from utils.metrics import *
from utils.transforms import get_func_heatmap_to_coord
from models.fasterpose import FastPose

# Number of CUDA devices reported by torch; batch sizes below are scaled by it.
num_gpu = torch.cuda.device_count()
# Validation batch size: one sample per device.
valid_batch = num_gpu * 1
# Normalization layer class, chosen by opt.sync.
norm_layer = nn.SyncBatchNorm if opt.sync else nn.BatchNorm2d


def _compute_loss_and_acc(output, labels, label_masks, criterion, loss_type, norm_type):
    """Compute the loss tensor and accuracy value for one batch.

    Args:
        output: Network output for the batch.
        labels: Target tensor, or a two-element list for the 'Combined' loss.
        label_masks: Mask tensor, or a two-element list for the 'Combined' loss.
        criterion: Loss callable, or a two-element sequence of callables for
            the 'Combined' loss.
        loss_type: Value of ``cfg.LOSS['TYPE']``.
        norm_type: Value of ``cfg.LOSS['NORM_TYPE']`` (may be ``None``).

    Returns:
        Tuple ``(loss, acc)``.
    """
    if loss_type == 'MSELoss':
        loss = 0.5 * criterion(output.mul(label_masks), labels.mul(label_masks))
        acc = calc_accuracy(output.mul(label_masks), labels.mul(label_masks))
        return loss, acc

    if loss_type == 'Combined':
        # The trailing channels of the output belong to the face/hand head:
        # 42 of them when the output has 68 channels, 110 otherwise.
        face_hand_num = 42 if output.size()[1] == 68 else 110

        output_body_foot = output[:, :-face_hand_num, :, :]
        output_face_hand = output[:, -face_hand_num:, :, :]
        num_body_foot = output_body_foot.shape[1]
        num_face_hand = output_face_hand.shape[1]

        label_masks_body_foot = label_masks[0]
        label_masks_face_hand = label_masks[1]
        labels_body_foot = labels[0]
        labels_face_hand = labels[1]

        loss_body_foot = 0.5 * criterion[0](output_body_foot.mul(label_masks_body_foot),
                                            labels_body_foot.mul(label_masks_body_foot))
        acc_body_foot = calc_accuracy(output_body_foot.mul(label_masks_body_foot),
                                      labels_body_foot.mul(label_masks_body_foot))

        loss_face_hand = criterion[1](output_face_hand, labels_face_hand, label_masks_face_hand)
        acc_face_hand = calc_integral_accuracy(output_face_hand, labels_face_hand, label_masks_face_hand,
                                               output_3d=False, norm_type=norm_type)

        # Fixed weights applied to the two loss terms before summing.
        loss_body_foot *= 100
        loss_face_hand *= 0.01

        loss = loss_body_foot + loss_face_hand
        # Accuracy is the channel-count-weighted mean of the two accuracies.
        acc = acc_body_foot * num_body_foot / (num_body_foot + num_face_hand) + acc_face_hand * num_face_hand / (
            num_body_foot + num_face_hand)
        return loss, acc

    loss = criterion(output, labels, label_masks)
    acc = calc_integral_accuracy(output, labels, label_masks, output_3d=False, norm_type=norm_type)
    return loss, acc


def train(opt, train_loader, m, criterion, optimizer, writer):
    """Run one pass over ``train_loader`` and update ``m`` after every batch.

    Args:
        opt: Options object; ``opt.trainIters`` is incremented once per batch,
            and ``opt.board`` / ``opt.debug`` toggle the writer calls.
        train_loader: Iterable yielding 9-element batches; elements 1, 2 and 3
            are used as inputs, labels and label masks.
        m: Model to train; it is switched to training mode.
        criterion: Loss callable (or a pair of callables for the 'Combined' loss).
        optimizer: Optimizer stepped once per batch.
        writer: Writer handed to ``board_writing`` / ``debug_writing``.

    Returns:
        Tuple ``(loss_logger.avg, acc_logger.avg)`` accumulated over the pass.
    """
    loss_logger = DataLogger()
    acc_logger = DataLogger()

    m.train()

    # Loop-invariant configuration, read once instead of on every batch.
    loss_type = cfg.LOSS.get('TYPE')
    norm_type = cfg.LOSS.get('NORM_TYPE', None)

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(train_loader, dynamic_ncols=True) as progress:
        for i, (_, inps, labels, label_masks, _, _, _, _, _) in enumerate(progress):
            if isinstance(inps, list):
                inps = [inp.cuda().requires_grad_() for inp in inps]
            else:
                inps = inps.cuda().requires_grad_()
            if isinstance(labels, list):
                labels = [label.cuda() for label in labels]
                label_masks = [label_mask.cuda() for label_mask in label_masks]
            else:
                labels = labels.cuda()
                label_masks = label_masks.cuda()

            output = m(inps)

            loss, acc = _compute_loss_and_acc(output, labels, label_masks, criterion, loss_type, norm_type)

            batch_size = inps[0].size(0) if isinstance(inps, list) else inps.size(0)

            loss_logger.update(loss.item(), batch_size)
            acc_logger.update(acc, batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            opt.trainIters += 1
            # Tensorboard
            if opt.board:
                board_writing(writer, loss_logger.avg, acc_logger.avg, opt.trainIters, 'Train')

            # Debug: only every 10th batch.
            if opt.debug and not i % 10:
                debug_writing(writer, output, labels, inps, opt.trainIters)

            # TQDM
            progress.set_description(
                'loss: {loss:.8f} | acc: {acc:.4f}'.format(
                    loss=loss_logger.avg,
                    acc=acc_logger.avg)
            )

    return loss_logger.avg, acc_logger.avg



def _build_train_loader():
    """Create a fresh Air2land training dataset and a shuffled DataLoader over it."""
    train_dataset = Air2land(cfg, mode="train")
    return torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, shuffle=True, num_workers=opt.nThreads)


def main():
    """Train FastPose on Air2land, checkpointing and validating after every epoch.

    Raises:
        NotImplementedError: If ``cfg.LOSS['TYPE']`` is 'Combined'; no
            criterion is built for that loss type in this script.
        ValueError: If ``cfg.TRAIN.OPTIMIZER`` is neither 'adam' nor 'rmsprop'.
    """
    logger.info('******************************')
    logger.info(opt)
    logger.info('******************************')
    logger.info(cfg)
    logger.info('******************************')

    # Fail fast with a clear message instead of an UnboundLocalError later on.
    if cfg.LOSS.get('TYPE') == 'Combined':
        raise NotImplementedError(
            "cfg.LOSS.TYPE == 'Combined' is not supported: no criterion is built for it")

    # Model Initialize
    m = FastPose(cfg["MODEL"])
    m = nn.DataParallel(m).cuda()

    criterion = nn.MSELoss().cuda()

    if cfg.TRAIN.OPTIMIZER == 'adam':
        optimizer = torch.optim.Adam(m.parameters(), lr=cfg.TRAIN.LR)
    elif cfg.TRAIN.OPTIMIZER == 'rmsprop':
        optimizer = torch.optim.RMSprop(m.parameters(), lr=cfg.TRAIN.LR)
    else:
        raise ValueError(
            f"Unsupported cfg.TRAIN.OPTIMIZER {cfg.TRAIN.OPTIMIZER!r}; expected 'adam' or 'rmsprop'")

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.TRAIN.LR_STEP, gamma=cfg.TRAIN.LR_FACTOR)

    exp_dir = './exp/{}-{}'.format(opt.exp_id, cfg.FILE_NAME)

    writer = SummaryWriter('.tensorboard/{}-{}'.format(opt.exp_id, cfg.FILE_NAME))
    try:
        train_loader = _build_train_loader()

        heatmap_to_coord = get_func_heatmap_to_coord(cfg)

        # torch.save does not create missing directories.
        os.makedirs(exp_dir, exist_ok=True)

        opt.trainIters = 0

        for i in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH):
            opt.epoch = i
            current_lr = optimizer.param_groups[0]['lr']

            logger.info(f'############# Starting Epoch {opt.epoch} | LR: {current_lr} #############')

            # Training
            loss, miou = train(opt, train_loader, m, criterion, optimizer, writer)
            logger.epochInfo('Train', opt.epoch, loss, miou)

            lr_scheduler.step()

            # Checkpoint and validate after every epoch (opt.snapshot is not consulted).
            torch.save(m.module.state_dict(), '{}/model_{}.pth'.format(exp_dir, opt.epoch))
            # Prediction Test
            with torch.no_grad():
                error = validate_pose(m.module, opt, heatmap_to_coord)
            # Log whatever validate_pose returned so the result is not silently discarded.
            logger.info(f'##### Epoch {opt.epoch} | validation error: {error} #####')

            # Time to add DPG
            if i == cfg.TRAIN.DPG_MILESTONE:
                torch.save(m.module.state_dict(), '{}/final.pth'.format(exp_dir))
                # Reset the learning rate and switch to the DPG schedule.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = cfg.TRAIN.LR
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, milestones=cfg.TRAIN.DPG_STEP, gamma=0.1)
                # Reset dataset
                train_loader = _build_train_loader()

        torch.save(m.module.state_dict(), '{}/final_DPG.pth'.format(exp_dir))
    finally:
        # Flush and release the tensorboard writer even if training fails.
        writer.close()


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
